# HETA_method.py
# Hessian-Enhanced Token Attribution (HETA) for decoder-only LMs (e.g., GPT-2/OPT/GPT-J).
# Components:
#  (1) Semantic gate M_T via attention rollout into the target position (last context token)
#  (2) Curvature sensitivity S_i via low-rank (Hutchinson) HVPs, optionally windowed
#  (3) Information contribution I_i via KL between original vs. embedding-masked inputs
# Final score: Attr_i = M_T[i] * ( beta * S_i + gamma * I_i )

import torch
import torch.nn.functional as F

@torch.no_grad()
def _attention_rollout_to_target(attentions, target_idx):
    """
    Compute Abnar & Zuidema-style attention rollout into a single target position.

    Args:
        attentions: non-empty sequence of per-layer tensors [B, H, T, T], ordered from the
            first layer to the last and indexed as [batch, head, query, key] (post-softmax).
        target_idx: query position whose incoming flow is wanted (T-1 for next-token
            prediction). Negative indices are accepted.

    Returns:
        M_T of shape [B, T]: entry [b, k] is the rolled-out flow from input token k into
        `target_idx`, normalized to sum to 1 over k.

    Raises:
        ValueError: if `attentions` is empty.
    """
    layers = list(attentions)
    if not layers:
        raise ValueError("attentions must contain at least one layer")

    # With A[b, q, k] = "query q reads from key k", the rollout through layers 1..L is the
    # product A_L @ ... @ A_1 (each later layer multiplies on the LEFT). We only need row
    # `target_idx` of that product, so we carry a [B, 1, T] row vector from the last layer
    # down to the first instead of forming full [T, T] products.
    flow = None
    for A_l in reversed(layers):
        # A_l: [B,H,T,T] -> mean over heads: [B,T,T]
        A_mean = A_l.mean(dim=1)
        # Add the residual (skip) connection and renormalize row-wise
        eye = torch.eye(A_mean.size(-1), device=A_mean.device).unsqueeze(0)
        A_res = A_mean + eye
        A_res = A_res / A_res.sum(dim=-1, keepdim=True).clamp(min=1e-12)

        if flow is None:
            # Row of the last layer that belongs to the target query
            flow = A_res[:, target_idx, :].unsqueeze(1)  # [B,1,T]
        else:
            flow = torch.bmm(flow, A_res)  # [B,1,T]

    flow_to_target = flow[:, 0, :]  # [B, T]
    # Normalize to simplex for gating
    flow_to_target = flow_to_target / flow_to_target.sum(dim=-1, keepdim=True).clamp(min=1e-12)
    return flow_to_target  # M_T

def _logprob_at_last(model, logits, input_ids):
    """
    Log-probabilities over the vocabulary taken from the final sequence position.

    logits: [B, T, V]; the row at position -1 is the distribution for the token that
    follows the context. `model` and `input_ids` are accepted but not used here.
    Returns a [B, V] tensor of log-probabilities.
    """
    final_step_logits = logits.select(1, -1)  # [B, V]
    return F.log_softmax(final_step_logits, dim=-1)

def _mask_with_embedding(tokenizer, input_ids, embed, i, scheme="zero"):
    """
    Build inputs_embeds in which position i has been replaced at the embedding level.

    input_ids is passed through `embed` and the result is cloned, so the embedding output
    itself is left untouched. With scheme "zero" the vector at position i becomes all
    zeros; with scheme "mean" it becomes the mean of the sequence's embeddings (position i
    included). Any other scheme raises ValueError. `tokenizer` is not used.
    Returns the modified [B,T,d] tensor.
    """
    masked = embed(input_ids).clone()  # [B,T,d]
    if scheme not in ("zero", "mean"):
        raise ValueError("Unknown masking scheme")
    if scheme == "mean":
        # Mean is taken over the sequence axis before position i is overwritten
        sequence_mean = masked.mean(dim=1)  # [B,d]
        masked[:, i, :] = sequence_mean
    else:
        masked[:, i, :] = 0
    return masked

def _hvp(fn, x, v):
    """
    Hessian-vector product H v of a scalar objective with respect to x (double backward).

    Args:
        fn: either a callable mapping x to a single-element tensor, or an already
            evaluated single-element tensor whose autograd graph contains x.
        x: tensor with requires_grad=True that the objective depends on.
        v: tensor multiplied elementwise with the gradient (same shape as x).

    Returns:
        Tensor with the same shape as x holding H v. If the gradient does not depend on
        x (zero curvature), zeros are returned.

    Raises:
        ValueError: if the objective has more than one element.
    """
    # enable_grad lets a callable objective build its graph even inside a no_grad caller
    with torch.enable_grad():
        objective = fn(x) if callable(fn) else fn
        if objective.numel() != 1:
            raise ValueError("_hvp expects a scalar (single-element) objective")
        # autograd.grad only creates the implicit grad_output for 0-dim outputs
        objective = objective.reshape(())

        (grad,) = torch.autograd.grad(objective, x, create_graph=True)
        if not grad.requires_grad:
            # Gradient is constant in x, so the Hessian is identically zero
            return torch.zeros_like(x)
        (hv,) = torch.autograd.grad((grad * v).sum(), x, retain_graph=False)
    return hv

def heta_attribution(
    model,
    tokenizer,
    text,
    beta=0.5,             # weight on Hessian sensitivity (S)
    gamma=0.5,            # weight on KL information (I); we keep beta+gamma=1 in sweeps
    device="cuda",
    window_size=None,     # e.g., 512; if None, use full context
    hutchinson_m=8,       # number of Rademacher probes for low-rank curvature estimate
    mask_scheme="zero",   # "zero" or "mean" embedding masking for KL
):
    """
    Return tokens and HETA attributions for the first generated token (next-token
    distribution at the last context position).

    Score per token: Attr_i = M_T[i] * (beta * S_i + gamma * I_i), where
      * M_T is the attention-rollout flow into the last position (full context),
      * S_i is the mean L1 norm of token i's block of H r over `hutchinson_m` independent
        Rademacher probes r supported on token i (H = Hessian of the log-prob of the
        model's argmax next token w.r.t. the input embeddings),
      * I_i is KL(P_orig || P_masked_i) of the next-token distribution when token i's
        embedding is masked.
    S and I are computed only for the last `window_size` tokens (zero elsewhere), each
    L1-normalized, and the final scores are renormalized to sum to 1.

    Args:
        model: causal LM called as model(input_ids, output_attentions=True) and
            model(inputs_embeds=...), exposing .logits / .attentions on the output and
            get_input_embeddings(). It is moved to `device` and put in eval mode.
        tokenizer: called as tokenizer(text, return_tensors="pt", add_special_tokens=False);
            also used to decode single token ids.
        text: input string.
        beta, gamma: nonnegative weights for the curvature and KL terms.
        device: device the model and inputs are moved to.
        window_size: number of trailing tokens to score; None or >= T means all tokens.
        hutchinson_m: number of probes per token (>= 1).
        mask_scheme: "zero" or "mean".

    Returns:
        (tokens, scores): list of T decoded token strings and list of T floats.

    Raises:
        ValueError: fewer than 2 tokens, hutchinson_m < 1, or unknown mask_scheme.
        RuntimeError: the model returned no attention weights.
    """
    assert 0 <= beta and 0 <= gamma, "beta,gamma must be nonnegative"
    if hutchinson_m < 1:
        raise ValueError("hutchinson_m must be at least 1")
    # Fail fast: otherwise a bad scheme is only detected after the expensive HVP loop
    if mask_scheme not in ("zero", "mean"):
        raise ValueError("Unknown masking scheme")
    model = model.to(device).eval()

    # Tokenize input; no BOS/EOS add to keep positions aligned
    enc = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(device)
    input_ids = enc["input_ids"]  # [1, T]
    T = input_ids.size(1)
    if T < 2:
        raise ValueError("Need at least 2 tokens to predict the next token.")

    # Optional windowing: restrict curvature & KL to a suffix window; keep gate on full context
    if window_size is None or window_size >= T:
        win_start = 0
    else:
        win_start = T - window_size
    win_len = max(T - win_start, 0)  # number of tokens included in curvature/KL

    # ---------- Forward pass with attentions ----------
    with torch.no_grad():
        out = model(input_ids, output_attentions=True)
        logits = out.logits  # [1, T, V]
        if out.attentions is None:
            raise RuntimeError("model did not return attention weights")
        attentions = list(out.attentions)  # length L, each [1,H,T,T]

    # Target distribution for the first generated token (the model would produce this next)
    logp_orig = _logprob_at_last(model, logits, input_ids)  # [1, V]

    # ---------- (1) Semantic gate M_T via rollout ending at target position (last context index = T-1) ----------
    M_T = _attention_rollout_to_target(attentions, target_idx=T - 1)[0]  # [T]
    # Numerical safety
    M_T = (M_T / M_T.sum().clamp(min=1e-12)).clamp(min=0.0)

    # ---------- (2) Curvature sensitivity S_i via low-rank HVPs (windowed) ----------
    # S_i ≈ avg_k || Π_i H (Π_i r_k) ||_1, with Rademacher probes supported on token i's block.
    embed = model.get_input_embeddings()
    X = embed(input_ids).detach().requires_grad_(True)  # [1, T, d]
    d = X.size(-1)

    # Scalar objective: log-prob of the *argmax* next token under the original distribution.
    # A Python int index keeps lp[0, target_id] 0-dimensional, as autograd.grad requires.
    target_id = int(logp_orig.argmax(dim=-1).item())

    def scalar_obj(X_in):
        out_emb = model(inputs_embeds=X_in)
        lp = F.log_softmax(out_emb.logits[:, -1, :], dim=-1)
        return lp[0, target_id]  # 0-dim scalar

    S = torch.zeros(T, device=device)
    # m INDEPENDENT ±1 probes for the window [win_start:T]; averaging identical copies of
    # one probe would not reduce the estimator's variance.
    probes = torch.randint(0, 2, (hutchinson_m, win_len, d), device=X.device).to(X.dtype) * 2 - 1

    # enable_grad so the double backward works even when the caller is inside no_grad
    with torch.enable_grad():
        for i in range(win_start, T):
            acc = 0.0
            for k in range(hutchinson_m):
                r = torch.zeros_like(X)
                r[0, i, :] = probes[k, i - win_start, :]
                # Pass the evaluated scalar (a tensor); a fresh graph is built per HVP
                hv = _hvp(scalar_obj(X), X, r)  # [1,T,d]
                acc += hv[0, i, :].abs().sum()
            S[i] = acc / float(hutchinson_m)

    # ---------- (3) Information contribution I_i via KL (windowed) ----------
    info = torch.zeros(T, device=device)
    with torch.no_grad():
        for i in range(win_start, T):
            X_mask = _mask_with_embedding(tokenizer, input_ids, embed, i, scheme=mask_scheme)
            out_mask = model(inputs_embeds=X_mask)
            logp_mask = F.log_softmax(out_mask.logits[:, -1, :], dim=-1)  # [1,V]
            # F.kl_div(input, target) computes KL(target || input); to get
            # KL(P_orig || P_mask) the original distribution must be the *target*.
            info[i] = F.kl_div(logp_mask, logp_orig, reduction="sum", log_target=True)

    # ---------- Final attribution ----------
    # Normalize S and I to comparable scale per-instance (L1-normalize over the window), then gate with M_T
    def _safe_norm(x):
        return x / x.sum().clamp(min=1e-12)

    Sn = _safe_norm(S.clamp(min=0))
    In = _safe_norm(info.clamp(min=0))
    Attr = M_T * (beta * Sn + gamma * In)

    # Renormalize final scores to sum to 1 for readability
    Attr = Attr / Attr.sum().clamp(min=1e-12)

    # Decode tokens one id at a time so tokens align with scores
    tokens = [tokenizer.decode([tid]) for tid in input_ids[0].tolist()]

    return tokens, Attr.detach().cpu().tolist()


if __name__ == "__main__":
    # Demo: load GPT-2, score every token of one prompt, and print one
    # "<token repr><TAB><score>" line per token.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "gpt2"
    run_device = "cuda" if torch.cuda.is_available() else "cpu"

    lm = AutoModelForCausalLM.from_pretrained(checkpoint, output_attentions=True)
    lm = lm.to(run_device)
    tok = AutoTokenizer.from_pretrained(checkpoint)
    tok.pad_token = tok.eos_token

    prompt = "The Eiffel Tower is in Paris. The square root of sixteen is What is the value?"
    demo_settings = dict(
        beta=0.5,
        gamma=0.5,
        device=run_device,
        window_size=256,
        hutchinson_m=8,
        mask_scheme="zero",
    )
    pieces, scores = heta_attribution(lm, tok, prompt, **demo_settings)

    for t, a in zip(pieces, scores):
        print(f"{t!r}\t{a:.4f}")
